import torch.nn as nn


class MLP(nn.Module):
    """Fully connected network mapping 2 input features to 2 outputs.

    Default layer stack held in ``self.fc`` (an ``nn.Sequential``)::

        Linear(2, 1024) -> Tanh -> Linear(1024, 512)
        -> Linear(512, 128) -> Linear(128, 2)

    NOTE: with the default configuration the last three Linear layers have
    no non-linearity between them, so they compose into a single affine map
    (1024 -> 2) and the extra depth adds no expressive power. The default
    is kept unchanged so existing outputs and checkpoints (state_dict keys
    ``fc.0`` .. ``fc.4``) remain valid. Pass ``hidden_activation`` to insert
    a non-linearity after the 512- and 128-unit layers.

    Args:
        hidden_activation: Optional ``nn.Module`` subclass (for example
            ``nn.GELU``). It is instantiated with no arguments and inserted
            after each of the 512- and 128-unit Linear layers. ``None``
            (the default) reproduces the original purely linear tail. Note
            that enabling it shifts the Sequential indices, so state_dicts
            are not interchangeable between the two configurations.
    """

    def __init__(self, *, hidden_activation=None):
        super().__init__()
        # Layers are created in the same order as before, and the activation
        # modules are parameter-free by default, so random weight
        # initialisation is unaffected when hidden_activation is None.
        layers = [
            nn.Linear(2, 1024),
            nn.Tanh(),
            nn.Linear(1024, 512),
        ]
        if hidden_activation is not None:
            layers.append(hidden_activation())
        layers.append(nn.Linear(512, 128))
        if hidden_activation is not None:
            layers.append(hidden_activation())
        layers.append(nn.Linear(128, 2))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x``.

        ``x`` must have a last dimension of size 2 (the input size of the
        first ``nn.Linear``). The output has the same leading dimensions as
        ``x`` and a last dimension of size 2.
        """
        out = self.fc(x)
        return out


def myMLP():
    """Construct an ``MLP`` with its default configuration and return it."""
    model = MLP()
    return model
